"""
Date: 2024-11-07 07:02:20
LastEditors: djw
LastEditTime: 2024-12-10 08:48:32
"""

import torch
from torch import nn
import queue
import signal
import queue
from typing import AsyncIterable
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
import asyncio
import multiprocessing
import time
import torch.multiprocessing as mp
import random
import torch.distributed as dist
import zmq
import tempfile
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput

from ktransformers.server.config.config import Config
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext



def pad_num_tokens(num_tokens):
    return (num_tokens + 63) // 64 * 64

def deduplicate_and_sort(lst):
    return sorted(set(lst))
class ModelRunner:
    """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

    model: KDeepseekV3ForCausalLM
    input: ForwardBatchInput | list[ForwardBatchInput]
    output: ForwardBatchOutput
    
    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
        
        self.stream = torch.cuda.Stream(device=device)
        # 先注释掉
        self.model = model  # Compile and move model to the specified device
        self.device = device
        self.input = None
        self.features_buf = None
        self.output = None
        self.graph_memory_pool = None
        self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
        self.use_cuda_graph = use_cuda_graph
        self.model_time = 0
        self.page_size = page_size
        # GPU timing for model execution
        self.start_model_event = torch.cuda.Event(enable_timing=True)
        self.end_model_event = torch.cuda.Event(enable_timing=True)
        if isinstance(self.cuda_graphs, list):
            self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
            self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
            self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
        else:
            self.graphs = torch.cuda.CUDAGraph()
            self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
            self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
        self.num_mini_batches = num_mini_batches

        self.max_chunk_size = max_chunk_size

        self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
        self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
    def warmup(self):

        def capture_graphs(cuda_graph_idx=-1):
            if cuda_graph_idx != -1:
                with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
                    self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)   
                self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
            else:
                with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)   
                self.graph_memory_pool = self.graphs.pool()

        if isinstance(self.cuda_graphs, list):
            self.input = []
            self.features_buf = []
            self.outputs_buf = []
            self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
            self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
            for i in range(len(self.cuda_graphs)):
                prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0  #@TODO only supprot 2 prefill batch
                self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i]))

                self.features_buf.append(self.model.batch_embeddings(self.input[i]))
                batch_size = self.input[i].minibatch.q_indptr.size(0)-1
                num_tokens = self.features_buf[i][0].size(0)
                print("capturing cuda graph", batch_size, num_tokens)
                self.bsz_tensor_buf[0] = batch_size
                self.num_tokens_tensor_buf[0] = num_tokens

                self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                                num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, 
                                                    head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                                    sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
            
                page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)

                self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
                self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
                self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) 
                
                self.outputs_buf.append(None)
            
                torch.cuda.synchronize()
                for warm_up_iters in range(11):
                    with torch.cuda.stream(self.stream):
                        self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
                torch.cuda.synchronize()

                capture_graphs(i)

                with torch.cuda.stream(self.stream):
                    self.graphs[i].replay()

                self.sync(calc_time=False)
                print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
        else:
            self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches)

            self.features_buf = self.model.batch_embeddings(self.input)
            batch_size = self.input.minibatch.q_indptr.size(0)-1
            num_tokens = self.features_buf[0].size(0)

            
            self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
            self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)


            self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                            num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, 
                                                head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                                sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
            
            page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
            self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
            self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
            self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) 
            
            
            torch.cuda.synchronize()
            for warm_up_iters in range(11):
                with torch.cuda.stream(self.stream):
                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
            torch.cuda.synchronize()

            def capture_graphs():
                with torch.cuda.graph(self.graphs,  stream=self.stream):
                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)   
                    # self.graph_memory_pool = self.graphs.pool()


            capture_graphs()

            with torch.cuda.stream(self.stream):
                self.graphs.replay()

            self.sync(calc_time=False)
            print("warmup finished.")
        
    def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
        with torch.cuda.stream(self.stream):

            batch_size = len(batch.prefill_mini_batches) # TODO: calc this
            num_tokens = 0
            for i in range(len(batch.decode_mini_batches)):
                batch_size += len(batch.decode_mini_batches[i])
                num_tokens += len(batch.decode_mini_batches[i])
                print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')

            for i in range(len(batch.prefill_mini_batches)):
                num_tokens += batch.prefill_mini_batches[i][2]
                print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')



            if isinstance(self.cuda_graphs, list):
                # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens
                cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
                if cuda_graph_idx == len(self.cuda_graphs):
                    assert False, "num_tokens is too large"
            else:
                cuda_graph_idx = -1
    
            if self.use_cuda_graph:
                if cuda_graph_idx != -1:
                    self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
                else:
                    self.input.fill(batch, query_manager, self.page_size)
            else:
                self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
                    

            if cuda_graph_idx != -1 and self.use_cuda_graph:
                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
            else:
                self.features = self.model.batch_embeddings(self.input, device=self.device)
                

            self.bsz_tensor_buf.copy_(batch_size)
            self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))

            if self.use_cuda_graph:
                if cuda_graph_idx != -1:
                    self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
                else:
                    self.features_buf[0].copy_(self.features[0], non_blocking=True)
                """
                if num_tokens_0 > 64:
                    padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
                    self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
                """
            #self.input.forward_minibatchs[0].print()
            # print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
            # print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")

            # self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)

            """
            if self.use_cuda_graph:
                print("before replay features_buf", self.features_buf[0])
                print("features_buf addr", self.features_buf[0].data_ptr())
            else:
                print("before run features", self.features[0])
            """
            if cuda_graph_idx != -1 and self.use_cuda_graph:
                self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                            num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, 
                                                head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                                sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                self.start_model_event.record(self.stream)
                page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
                if self.use_cuda_graph:
                    self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
                    self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
                    self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
                    self.replay(cuda_graph_idx)
                    self.output = ForwardBatchOutput()
                    
                    self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                    self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)

                    self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
                else:
                    self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                    self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
                self.end_model_event.record(self.stream)
            else:
                self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                            num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, 
                                                head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                                sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
                self.start_model_event.record(self.stream)
                page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
                if self.use_cuda_graph:
                    self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
                    self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
                    self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) 
                    self.replay(cuda_graph_idx)
                    self.output = ForwardBatchOutput()

                    self.output.top_ps.append(self.input.minibatch.top_ps)
                    self.output.temperatures.append(self.input.minibatch.temperatures)

                    self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
                else:
                    self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                    self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
                    self.output.top_ps.append(self.input.minibatch.top_ps)
                    self.output.temperatures.append(self.input.minibatch.temperatures)

                self.end_model_event.record(self.stream)

        if not self.use_cuda_graph:
            self.output.num_batchs = self.input.batch_size
        else:
            self.output.num_batchs = self.input[cuda_graph_idx].batch_size


    def replay(self, cuda_graph_idx=-1):
        with torch.cuda.stream(self.stream):
            if cuda_graph_idx != -1:
                self.graphs[cuda_graph_idx].replay()
            else:
                self.graphs.replay()


    def sync(self, calc_time = True):
        self.stream.synchronize()
        if calc_time:
            self.model_time = self.start_model_event.elapsed_time(self.end_model_event)  # In ms